# Fix the RNG seed so the sampler runs and the ground-truth draws are reproducible.
set.seed(seed = 66)


# Energy (negative log-density) of the target: the two-component Gaussian
# mixture 0.4 * N(-3, 0.7^2) + 0.6 * N(2, 0.5^2). Vectorized in x.
#
# Computed as a log-sum-exp over the log-weighted component log-densities
# instead of -log(0.4 * dnorm(...) + 0.6 * dnorm(...)): dnorm() underflows to 0
# once x is a few dozen units away from both means, which would make the energy
# Inf and any finite-difference gradient of it undefined.
energy_func = function(x) {
    log_a = log(0.4) + dnorm(x, mean=-3, sd=0.7, log=TRUE)
    log_b = log(0.6) + dnorm(x, mean=2, sd=0.5, log=TRUE)
    top = pmax(log_a, log_b)
    gap = abs(log_a - log_b)
    # The gap is NaN when both log-terms are -Inf (infinite x); mapping it to Inf
    # makes the energy Inf there, matching the direct formula.
    gap[is.nan(gap)] = Inf
    -(top + log1p(exp(-gap)))
}

# You may need more iterations to obtain stable results (otherwise the bias-corrected swaps tend to yield a large variance).
total = 3e6       # PT iterations (the GLD baseline runs 2 * total)
lr = 0.003        # Langevin step size
T_low = 1         # temperature of the chain whose samples are returned
T_high = 10       # temperature of the exploring chain in PT
thinning = 1e4    # keep one sample every `thinning` PT iterations


# Baseline: plain gradient Langevin dynamics (GLD) on a single chain at
# temperature T_low, with no replica exchange. It is included to show that GLD
# alone is slow to move between the modes of this Gaussian mixture.
#
# Runs 2 * total Langevin steps from x = 0 (each PT iteration makes two gradient
# calls, so this matches PT's gradient budget) and keeps the state every
# 2 * thinning steps, returning total %/% thinning samples -- the same count as
# non_reversible_PT().
# Uses the globals energy_func, total, lr, T_low and thinning.
GLD = function() {
    n_iters = 2 * total
    keep_every = 2 * thinning
    samples_GLD = numeric(n_iters %/% keep_every)
    n_kept = 0
    x_low = 0
    for (iters in seq_len(n_iters)) {
        x_low  = x_low  - lr * numDeriv::grad(energy_func, x_low)  + sqrt(2 * lr * T_low)  * rnorm(1, 0, 1)

        if (iters %% keep_every == 0) {
            n_kept = n_kept + 1
            samples_GLD[n_kept] = x_low
        }
        # Progress is reported against the GLD budget (2 * total), not total;
        # sprintf avoids paste() rendering large doubles as "3e+06".
        if (iters %% (n_iters / 20) == 0) {
            print(sprintf("Iterations %d / %d", iters, n_iters))
        }
    }
    return(samples_GLD)
}

# Run the GLD baseline; its samples are plotted alongside the PT runs below.
samples_GLD = GLD()

# Two-chain parallel tempering (PT) on energy_func with Langevin moves, a
# windowed non-reversible swap schedule, and optionally noisy energies.
#
# Every iteration moves x_low (temperature T_low) and x_high (temperature
# T_high) by one Langevin step. A swap may only happen when
# floor(iters / window) is odd, and at most once per window: `gate` is re-armed
# at every multiple of `window` and closed after a swap.
#
# The swap test uses energies perturbed by energy_sd * N(0, 1) noise.
# `correction` = (1/T_high - 1/T_low) * energy_sd^2 offsets that noise: with
# independent noise on both energies, it makes the expected value of the
# exponential ratio inside min() equal its noise-free value.
#
# Args:
#   window:    number of iterations per swap window.
#   energy_sd: standard deviation of the noise added to each energy
#              evaluation in the swap test (0 = exact energies).
# Returns the x_low state every `thinning` iterations (total %/% thinning values).
# Uses the globals energy_func, total, lr, T_low, T_high and thinning.
non_reversible_PT = function(window=1, energy_sd=0) {
    noisy_func = function(x) return(energy_func(x) + energy_sd * rnorm(1, 0, 1))
    inv_temp_gap = 1 / T_high - 1 / T_low
    correction = inv_temp_gap * energy_sd^2
    samples_PT = numeric(total %/% thinning)
    n_kept = 0
    x_low = 0
    x_high = 0
    gate = 1
    for (iters in seq_len(total)) {
        # Re-arm the swap gate at the start of each window.
        if (iters %% window == 0) {
            gate = 1
        }
        x_low  = x_low  - lr * numDeriv::grad(energy_func, x_low)  + sqrt(2 * lr * T_low)  * rnorm(1, 0, 1)
        x_high = x_high - lr * numDeriv::grad(energy_func, x_high) + sqrt(2 * lr * T_high) * rnorm(1, 0, 1)

        # Evaluated every iteration, even outside swap windows; keeping it
        # unconditional leaves the random-number stream (and seeded results)
        # unchanged.
        swap_rate = min(1, exp(inv_temp_gap * (noisy_func(x_high) - noisy_func(x_low) - correction)))

        in_swap_window = (1 == (floor(iters / window) %% 2))
        if (runif(1) < swap_rate && in_swap_window && gate == 1) {
            tmp = x_low
            x_low = x_high
            x_high = tmp
            gate = 0
        }
        if (iters %% thinning == 0) {
            n_kept = n_kept + 1
            samples_PT[n_kept] = x_low
        }
        if (iters %% (total / 20) == 0) {
            # sprintf avoids paste() rendering large doubles as "3e+06".
            print(sprintf("Iterations %d / %d", iters, total))
        }
    }
    return(samples_PT)
}

# Window length for the non-reversible swap schedule shared by all PT runs.
window = 5

# One PT run per energy-noise level; energy_sd = 0 uses exact energies.
samples_PT_std0 = non_reversible_PT(window=window, energy_sd=0)
samples_PT_std1 = non_reversible_PT(window=window, energy_sd=1)
samples_PT_std2 = non_reversible_PT(window=window, energy_sd=2)
samples_PT_std3 = non_reversible_PT(window=window, energy_sd=3)

# Exact draws from the target mixture, sized to match the PT output. The first
# component's count is rounded and the second takes the remainder, so the two
# always add up to n_real; truncating 0.4 * n and 0.6 * n separately can drop a
# sample and leave the ground truth shorter than the sampler outputs.
n_real = length(samples_PT_std0)
n_left = round(0.4 * n_real)
real_samples = c(rnorm(n_left, mean=-3, sd=0.7), rnorm(n_real - n_left, mean=2, sd=0.5))



library(ggplot2)


# Overlay kernel density estimates of every sampler's output on exact draws
# from the target mixture.
#
# Labels are repeated per series via lengths(), so series need not share a
# length. Colours and line types are keyed by label, so the mapping does not
# depend on the factor-level order (factor() sorts labels alphabetically).
plot_series = list(
    "Ground truth"   = real_samples,
    "GLD (2x iters)" = samples_GLD,
    "PT (sd=0)"      = samples_PT_std0,
    "PT (sd=1)"      = samples_PT_std1,
    "PT (sd=2)"      = samples_PT_std2,
    "PT (sd=3)"      = samples_PT_std3)

wdata = data.frame(
        Type = factor(rep(names(plot_series), times=lengths(plot_series))),
        weight = unlist(plot_series, use.names=FALSE))

series_linetypes = c(
    "Ground truth"   = "solid",
    "GLD (2x iters)" = "longdash",
    "PT (sd=0)"      = "longdash",
    "PT (sd=1)"      = "longdash",
    "PT (sd=2)"      = "longdash",
    "PT (sd=3)"      = "solid")
series_colours = c(
    "Ground truth"   = "black",
    "GLD (2x iters)" = "red",
    "PT (sd=0)"      = "cyan1",
    "PT (sd=1)"      = "cyan3",
    "PT (sd=2)"      = "cyan4",
    "PT (sd=3)"      = "darkblue")

p=ggplot(wdata, aes(x = weight)) +
    # `size` is kept for compatibility with ggplot2 < 3.4 (newer versions prefer `linewidth`).
    stat_density(aes(colour=Type, linetype=Type), size=2, geom="line", position="identity") +
    scale_linetype_manual(values=series_linetypes) +
    scale_color_manual(values=series_colours) +
    scale_x_continuous(name="X") +
    ggtitle(paste0("PT with Window=", window)) +
    scale_y_continuous(name="Density") +
    # Zoom with the coordinate system: scale limits would drop density values
    # above 0.3 (the right mode peaks near 0.48) and break the lines.
    coord_cartesian(ylim=c(0, 0.3)) +
    theme(
        legend.position = c(0.3, 0.8),
        legend.text = element_text(colour="grey15", size=24),
        legend.key.size = unit(2,"line"),
        legend.key = element_blank(),
        legend.background=element_rect(fill=alpha('grey', 0.5)),
        axis.title=element_text(size=24),
        axis.text.y = element_text(size=24),
        axis.text.x = element_text(size=24),
        plot.title = element_text(size=32)
        )
p


